{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "Dash Chatbot Colab Demo.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "authorship_tag": "ABX9TyM4OlVzQCyJ+CJavl+2kDkU"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8n6CYTIja4hq",
        "colab_type": "text"
      },
      "source": [
        "To start this Jupyter Dash app, please run all the cells below. Then, click on the **temporary** URL at the end of the last cell to open the app."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "pPEkJcCxSqxs",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# dash-bootstrap-components is capped below 1.0 because the layout cell uses\n",
        "# dbc.InputGroupAddon, which was removed in the 1.0 release.\n",
        "%pip install -q jupyter-dash==0.3.0rc1 \"dash-bootstrap-components<1\" transformers"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "wNQanZGJcJhx",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import time\n",
        "\n",
        "import dash\n",
        "import dash_html_components as html\n",
        "import dash_core_components as dcc\n",
        "import dash_bootstrap_components as dbc\n",
        "from dash.dependencies import Input, Output, State\n",
        "from jupyter_dash import JupyterDash\n",
        "from transformers import AutoModelWithLMHead, AutoTokenizer\n",
        "import torch"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "vmYDDzbccMX0",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Use the GPU when torch can see one, otherwise fall back to the CPU\n",
        "if torch.cuda.is_available():\n",
        "    device = \"cuda\"\n",
        "else:\n",
        "    device = \"cpu\"\n",
        "print(f\"Device: {device}\")\n",
        "\n",
        "print(\"Start loading model...\")\n",
        "name = \"microsoft/DialoGPT-large\"\n",
        "tokenizer = AutoTokenizer.from_pretrained(name)\n",
        "model = AutoModelWithLMHead.from_pretrained(name)\n",
        "\n",
        "# On GPU, cast the weights to FP16 for faster inference\n",
        "if device == \"cuda\":\n",
        "    model = model.half()\n",
        "\n",
        "# Move the model to the selected device and put it in eval mode\n",
        "model = model.to(device).eval()\n",
        "\n",
        "print(\"Done.\")"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "QXEtdeyicNuF",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def textbox(text, box=\"other\"):\n",
        "    \"\"\"Wrap `text` in a rounded Bootstrap card styled as a chat bubble.\n",
        "\n",
        "    `box=\"self\"` yields a card pushed to the right (auto left margin) with the\n",
        "    \"primary\" colour and inverted text; `box=\"other\"` yields a card pushed to\n",
        "    the left (auto right margin) with the \"light\" colour.\n",
        "\n",
        "    Raises:\n",
        "        ValueError: if `box` is neither \"self\" nor \"other\".\n",
        "    \"\"\"\n",
        "    if box not in (\"self\", \"other\"):\n",
        "        raise ValueError(\"Incorrect option for `box`.\")\n",
        "\n",
        "    from_self = box == \"self\"\n",
        "\n",
        "    # The auto margin sits on the side opposite to where the bubble is anchored\n",
        "    style = {\n",
        "        \"max-width\": \"55%\",\n",
        "        \"width\": \"max-content\",\n",
        "        \"padding\": \"10px 15px\",\n",
        "        \"border-radius\": \"25px\",\n",
        "        \"margin-left\": \"auto\" if from_self else 0,\n",
        "        \"margin-right\": 0 if from_self else \"auto\",\n",
        "    }\n",
        "\n",
        "    return dbc.Card(\n",
        "        text,\n",
        "        style=style,\n",
        "        body=True,\n",
        "        color=\"primary\" if from_self else \"light\",\n",
        "        inverse=from_self,\n",
        "    )"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "vGiS2sIPcUIt",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Scrollable pane that the display callback fills with chat bubbles\n",
        "conversation = html.Div(\n",
        "    id=\"display-conversation\",\n",
        "    style={\n",
        "        \"width\": \"80%\",\n",
        "        \"max-width\": \"800px\",\n",
        "        \"height\": \"70vh\",\n",
        "        \"margin\": \"auto\",\n",
        "        \"overflow-y\": \"auto\",\n",
        "    },\n",
        ")\n",
        "\n",
        "# Text field with an appended submit button, same width as the pane above\n",
        "controls = dbc.InputGroup(\n",
        "    [\n",
        "        dbc.Input(id=\"user-input\", placeholder=\"Write to the chatbot...\", type=\"text\"),\n",
        "        dbc.InputGroupAddon(dbc.Button(\"Submit\", id=\"submit\"), addon_type=\"append\"),\n",
        "    ],\n",
        "    style={\"width\": \"80%\", \"max-width\": \"800px\", \"margin\": \"auto\"},\n",
        ")\n",
        "\n",
        "\n",
        "# Create the Dash app with the Bootstrap stylesheet and expose its server\n",
        "app = JupyterDash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])\n",
        "server = app.server\n",
        "\n",
        "\n",
        "# Page layout: title, rule, conversation store, chat pane, input controls\n",
        "app.layout = dbc.Container(\n",
        "    [\n",
        "        html.H1(\"Dash Chatbot (with DialoGPT)\"),\n",
        "        html.Hr(),\n",
        "        dcc.Store(id=\"store-conversation\", data=\"\"),\n",
        "        conversation,\n",
        "        controls,\n",
        "    ],\n",
        "    fluid=True,\n",
        ")"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "9n6HT0CIcWJ9",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "@app.callback(\n",
        "    Output(\"display-conversation\", \"children\"), [Input(\"store-conversation\", \"data\")]\n",
        ")\n",
        "def update_display(chat_history):\n",
        "    \"\"\"Render the stored conversation string as a list of chat bubbles.\n",
        "\n",
        "    The history is split on the tokenizer's EOS token and the trailing piece\n",
        "    after the last EOS is dropped. Pieces at even positions are drawn as\n",
        "    \"self\" bubbles and pieces at odd positions as \"other\" bubbles.\n",
        "    \"\"\"\n",
        "    turns = chat_history.split(tokenizer.eos_token)[:-1]\n",
        "    bubbles = []\n",
        "    for position, message in enumerate(turns):\n",
        "        sender = \"other\" if position % 2 else \"self\"\n",
        "        bubbles.append(textbox(message, box=sender))\n",
        "    return bubbles\n",
        "\n",
        "\n",
        "@app.callback(\n",
        "    [Output(\"store-conversation\", \"data\"), Output(\"user-input\", \"value\")],\n",
        "    [Input(\"submit\", \"n_clicks\"), Input(\"user-input\", \"n_submit\")],\n",
        "    [State(\"user-input\", \"value\"), State(\"store-conversation\", \"data\")],\n",
        ")\n",
        "def run_chatbot(n_clicks, n_submit, user_input, chat_history):\n",
        "    \"\"\"Append the user's message to the conversation and generate a reply.\n",
        "\n",
        "    Fires when the \"submit\" button's `n_clicks` or the \"user-input\" field's\n",
        "    `n_submit` changes.\n",
        "\n",
        "    Args:\n",
        "        n_clicks: `n_clicks` property of the \"submit\" button.\n",
        "        n_submit: `n_submit` property of the \"user-input\" field.\n",
        "        user_input: current text of the \"user-input\" field.\n",
        "        chat_history: conversation so far, as a single string in which each\n",
        "            turn is terminated by `tokenizer.eos_token`.\n",
        "\n",
        "    Returns:\n",
        "        A `(chat_history, input_value)` tuple: the updated conversation string\n",
        "        for the store, and \"\" to clear the input field.\n",
        "    \"\"\"\n",
        "    # Nothing has been submitted yet (initial callback fire). Both counters are\n",
        "    # checked so that a message sent with Enter is not discarded when the\n",
        "    # button has never been clicked.\n",
        "    if not n_clicks and not n_submit:\n",
        "        return \"\", \"\"\n",
        "\n",
        "    # Treat an unset store as an empty conversation so concatenation is safe\n",
        "    chat_history = chat_history or \"\"\n",
        "\n",
        "    # Empty submission: leave the conversation untouched\n",
        "    if not user_input:\n",
        "        return chat_history, \"\"\n",
        "\n",
        "    # Encode the conversation plus the new message (terminated by the EOS\n",
        "    # token) as a PyTorch tensor on the model's device\n",
        "    bot_input_ids = tokenizer.encode(\n",
        "        chat_history + user_input + tokenizer.eos_token, return_tensors=\"pt\"\n",
        "    ).to(device)\n",
        "\n",
        "    # Generate a reply; max_length caps history + reply at 1024 tokens.\n",
        "    # NOTE(review): once the encoded history approaches that cap there is no\n",
        "    # room left for a reply - consider trimming old turns before encoding.\n",
        "    chat_history_ids = model.generate(\n",
        "        bot_input_ids, max_length=1024, pad_token_id=tokenizer.eos_token_id\n",
        "    )\n",
        "\n",
        "    # Decode the full sequence (history + reply) back into the stored string\n",
        "    chat_history = tokenizer.decode(chat_history_ids[0])\n",
        "\n",
        "    return chat_history, \"\""
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "tgwQH_TXckgN",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Serve the app in external mode; open the URL shown in this cell's output\n",
        "app.run_server(mode=\"external\")"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}